import pandas as pd
import boto3
import botocore
import csv
import warnings
import os

from botocore import UNSIGNED
from botocore.client import Config
from botocore.handlers import disable_signing

# Silence every FutureWarning for the whole process (presumably to quiet pandas
# deprecation notices -- NOTE(review): this also hides removals coming in newer pandas).
warnings.simplefilter(action = 'ignore', category = FutureWarning)

# Site names accepted by --site / the `site` argument; compared with the 'site' column.
SITES = [
    'Site-CBIC',
    'Site-CUNY',
    'Site-RU',
    'Site-SI',
]
# Scan types accepted by --scans; each is matched as a substring of a file path.
SCANS = [
    'anat',
    'func',
    'dwi',
    'fmap',
]
# Task codes accepted by --tasks, mapped to the label searched for in each file path.
TASKS_MAP = {
    'REST': 'task-rest',
    'REST1': 'task-rest_run-1',
    'REST2': 'task-rest_run-2',
    'PEER1': 'task-peer_run-1',
    'PEER2': 'task-peer_run-2',
    'PEER3': 'task-peer_run-3',
    'MOVIEDM': 'task-movieDM',
    'MOVIETP': 'task-movieTP',
}

def files(client, bucket, prefix = ''):
    """Yield the keys of participants.tsv files found under ``prefix`` in ``bucket``.

    The listing uses 'participants.tsv' as its delimiter, so S3 reports each
    matching key in CommonPrefixes, cut off right after that file name. This is a
    generator: nothing is requested from ``client`` until it is iterated, and
    pages are fetched as they are consumed.
    """
    pages = client.get_paginator('list_objects').paginate(
        Bucket = bucket, Prefix = prefix, Delimiter = 'participants.tsv')
    for page in pages:
        for common_prefix in page.get('CommonPrefixes', []):
            yield common_prefix.get('Prefix')

def generate_Subfolders(s3_client):
    """Collect, print and return every participants.tsv key under the HBN MRI prefix.

    Parameters:
    s3_client: S3 client used to list the 'fcp-indi' bucket (passed through to files())

    Returns:
    (list): the keys yielded by files(), in listing order
    """
    hbn_prefix = 'data/Projects/HBN/MRI/'
    subfolders = list(files(s3_client, 'fcp-indi', prefix = hbn_prefix))
    print(subfolders)
    return subfolders

def _load_aws_links(aws_links):
    """Load the HBN AWS links table into a DataFrame, or return None if unavailable.

    ``aws_links`` is used when it names an existing file. Otherwise a copy of
    HBN_aws_links.csv in the current folder is used, fetching it with wget first
    if there is none. The reason is printed whenever None is returned.
    """
    local_copy = 'HBN_aws_links.csv'
    # Truthiness check first: the legacy default 0 (or None/'') means "not given".
    if aws_links and os.path.isfile(aws_links):
        return pd.read_csv(aws_links, na_values = ['n/a'])
    if aws_links:
        print(f'Could not find {aws_links}, looking for {local_copy} instead...')

    if not os.path.isfile(local_copy):
        # NOTE(review): 'url/to/HBN_aws_links.csv' is a placeholder inherited from the
        # original script; it must be replaced with the real location of the table.
        os.system('wget url/to/HBN_aws_links.csv .')
        if not os.path.isfile(local_copy):
            print(f'{local_copy} could not be downloaded.')
            print('Pass the path to an existing copy (aws_links / --aws_links) and try again.')
            return None
        print('HBN_aws_links.csv saved to current folder')
    return pd.read_csv(local_copy, na_values = ['n/a'])


def _filter_links(links_df, age_min, age_max, sex, sites, scans, tasks):
    """Return the rows of the AWS links table that satisfy every selection criterion.

    Age bounds are applied only when truthy. A comparison against a missing age is
    False in pandas, so any active bound also drops rows whose age is missing.
    """
    selected = links_df
    if age_min:
        selected = selected[selected['age'] >= age_min]
    if age_max:
        selected = selected[selected['age'] <= age_max]
    if sex in ('M', 'F'):
        selected = selected[selected['gender'] == sex]
    if scans:
        # na=False: rows without a filepath never match.
        selected = selected[selected['filepath'].str.contains('|'.join(scans), na = False)]
    if tasks:
        task_labels = [TASKS_MAP[task] for task in tasks if task in TASKS_MAP]
        if not task_labels:
            # An empty alternation would match every path; select nothing instead.
            print('None of the requested tasks are valid: ' + ' '.join(map(str, tasks)))
            return selected.iloc[0:0]
        selected = selected[selected['filepath'].str.contains('|'.join(task_labels), na = False)]
    return selected[selected['site'].isin(sites)]


def _download_keys(s3_keylist, out_dir, s3_prefix, s3_bucket_name, s3_client, dry_run):
    """Download each S3 key below ``out_dir``, or only report it when ``dry_run``.

    Keys are mirrored locally relative to ``s3_prefix``. Files that already exist
    are skipped. Returns ``(files_transferred, files_failed)``; on a dry run,
    ``files_transferred`` counts the files that would be downloaded.
    """
    total_files = len(s3_keylist)
    transferred = 0
    failed = 0
    for path_idx, s3_path in enumerate(s3_keylist):
        s3_path = s3_path.replace(f's3://{s3_bucket_name}/', '')
        rel_path = s3_path.replace(s3_prefix, '').lstrip('/')
        download_file = os.path.join(out_dir, rel_path)

        if os.path.exists(download_file):
            print(f'File {download_file} already exists, skipping...')
            continue
        if dry_run:
            print(f'File would be downloaded to {download_file}')
            transferred += 1
            continue

        download_dir = os.path.dirname(download_file)
        if download_dir:
            os.makedirs(download_dir, exist_ok = True)
        print(f'Downloading files to {download_file}')
        partial_file = download_file + '.part'
        try:
            with open(partial_file, 'wb') as f:
                s3_client.download_fileobj(s3_bucket_name, s3_path, f)
            # Publish under the final name only once complete, so a failed or
            # interrupted transfer never leaves a truncated file that later runs skip.
            os.replace(partial_file, download_file)
        except Exception as exc:
            print(f'There was a problem downloading {s3_path}.')
            print('Check input arguments and try again')
            print(exc)
            failed += 1
            try:
                os.remove(partial_file)
            except OSError:
                pass  # best effort: the partial file may never have been created
            continue
        transferred += 1
        print('%.3f%% percent completed' % (100 * (float(path_idx + 1) / total_files)))
    return transferred, failed


def _save_participants(participants_df, out_dir):
    """Merge ``participants_df`` into out_dir/participants.tsv, dropping duplicate rows."""
    print('Saving out revised participants.tsv file.')
    tsv_path = os.path.join(out_dir, 'participants.tsv')
    if os.path.isfile(tsv_path):
        old_participants_df = pd.read_csv(tsv_path, delimiter = '\t', na_values = ['n/a', 'N/A'])
        # DataFrame.append was removed in pandas 2.0; concat is the supported equivalent.
        participants_df = pd.concat([participants_df, old_participants_df], ignore_index = True)
        participants_df = participants_df.drop_duplicates()
    # to_csv overwrites the existing file, so no prior removal is needed.
    participants_df.to_csv(tsv_path, sep = '\t', na_rep = 'n/a', index = False)


def download_data(out_dir,
                  aws_links = None,
                  age_min = 0,
                  age_max = 100,
                  sex = '',
                  site = None,
                  scans = None,
                  tasks = None,
                  dry_run = False):
    """
    Download images from the Healthy Brain Network directory on FCP-INDI's S3 bucket.

    Parameters:
    out_dir (string): filepath to a local directory to save the files (created if missing)
    aws_links (string): filepath of HBN_aws_links.csv. If it is not given or is not an
        existing file, HBN_aws_links.csv in the current folder is used (fetched with wget
        when absent).
    age_min (float): minimum age (in years) for participants of interest; 0/None disables it
    age_max (float): maximum age (in years) for participants of interest; 0/None disables it.
        Any active age bound (including the default 100) drops rows whose age is missing.
    sex (string): 'M' or 'F' to download only male or female data; any other value keeps both
    site (list or string): the name(s) of the sites to download (e.g. Site-SI, Site-CBIC);
        defaults to every site in SITES
    scans (list): the scan types to download (e.g. anat, dwi, fmap, func)
    tasks (list): the task codes to download (keys of TASKS_MAP, e.g. REST1, MOVIEDM);
        unknown codes are ignored, and if none are known nothing is selected
    dry_run (boolean): whether or not to perform a dry run (i.e., no actual downloads, just
        listing files that would be downloaded)

    Returns:
    (boolean): True if every selected file was downloaded (or, on a dry run, listed);
    False if the links table could not be loaded, nothing matched, or any download failed
    """
    s3_bucket_name = 'fcp-indi'
    s3_prefix = 'data/Projects/HBN/MRI'

    if not os.path.exists(out_dir) and not dry_run:
        print(f'Could not find {out_dir}, creating now...')
        os.makedirs(out_dir)

    links_df = _load_aws_links(aws_links)
    if links_df is None:
        return False

    if site is None:
        site = SITES
    elif isinstance(site, str):
        # Series.isin needs a list-like; accept a single site name too.
        site = [site]

    print('Getting images of interest...')
    participants_df = _filter_links(links_df, age_min, age_max, sex, site, scans, tasks)

    if len(participants_df) == 0:
        print('No participants met the given criteria.')
        print('Check the input arguments and run this script again.')
        return False

    # Unique keys, sorted so runs download in a reproducible order.
    s3_keylist = sorted(participants_df['filepath'].dropna().unique())
    # Rows may repeat exactly; keep one copy of each before reporting and saving.
    participants_df = participants_df.drop_duplicates()

    # The unsigned client is only needed when something is actually transferred.
    s3_client = None
    if not dry_run:
        s3_client = boto3.client('s3', config = Config(signature_version = UNSIGNED))

    files_transferred, files_failed = _download_keys(
        s3_keylist, out_dir, s3_prefix, s3_bucket_name, s3_client, dry_run)

    print(participants_df)

    n_participants = participants_df['subject'].nunique()
    if dry_run:
        print(f'{files_transferred} files would be downloaded for {n_participants} participants')
        return True

    print(f'{files_transferred} files downloaded for {n_participants} participants')
    if files_failed:
        print(f'{files_failed} files could not be downloaded.')
    _save_participants(participants_df, out_dir)
    return files_failed == 0
    

if __name__ == '__main__':
    import argparse
    import sys

    parser = argparse.ArgumentParser(description=__doc__)

    parser.add_argument('--out_dir', required = True, type = str,
                        help = 'Output directory where files will be saved')
    parser.add_argument('--aws_links', required = False, type = str,
                        help = 'Path to HBN_aws_links.csv. Leave it empty to let the script search and download to the current directory')
    parser.add_argument('--age_min', required = False, type = float,
                        help = 'Minimum participant age (in years) to download (e.g., for subjects 30 or older, \'--age_min 30\')')
    parser.add_argument('--age_max', required = False, type = float,
                        help = 'Maximum participant age (in years) to download (e.g., for subjects 30 or younger, \'--age_max 30\')')
    parser.add_argument('--sex', required = False, type = str,
                        help = 'Participant sex (e.g., M or F)')
    parser.add_argument('--site', required = False, nargs = '*', type = str,
                        help = 'A space-separated list of site names to download (e.g., ' + ', '.join(SITES) + ')')
    parser.add_argument('--scans', required = False, nargs = '*', type = str,
                        help = 'A space-separated list of scan types to download (e.g., ' + ', '.join(SCANS) + ')')
    parser.add_argument('--tasks', required = False, nargs = '*', type = str,
                        help = 'A space-separated list of task codes to download (e.g. REST1, REST2, MOVIEDM)')
    parser.add_argument('--dry_run', required = False, action = 'store_true',
                        help = 'Dry run to see how many files would be downloaded')

    args = parser.parse_args()

    out_dir = os.path.abspath(args.out_dir)
    # Only options the user actually set are forwarded; download_data supplies the defaults.
    kwargs = {}
    if args.aws_links:
        kwargs['aws_links'] = args.aws_links
    elif os.path.exists('HBN_aws_links.csv'):
        print('Found HBN_aws_links.csv in the current working directory.')
        kwargs['aws_links'] = 'HBN_aws_links.csv'

    if args.age_min:
        kwargs['age_min'] = args.age_min
        print(f"Minimum age set to {kwargs['age_min']}")

    if args.age_max:
        kwargs['age_max'] = args.age_max
        print(f"Maximum age set to {kwargs['age_max']}")

    if 'age_min' in kwargs and 'age_max' in kwargs and kwargs['age_min'] > kwargs['age_max']:
        print(f"Minimum age {kwargs['age_min']} is greater than maximum age {kwargs['age_max']}")
        print('Check the script syntax and try again')
        sys.exit(1)

    if args.sex:
        kwargs['sex'] = args.sex.upper()
        if kwargs['sex'] == 'M':
            print('Downloading only male participants')
        elif kwargs['sex'] == 'F':
            print('Downloading only female participants')
        else:
            print(f"Input for sex {kwargs['sex']} is not valid")
            print('Valid sex inputs: M, F')
            print('Check the script syntax and try again')
            sys.exit(1)

    if args.site:
        kwargs['site'] = args.site
        for site in kwargs['site']:
            if site not in SITES:
                print(f'Site {site} is not a valid site name.')
                print(f"Valid site names: {', '.join(SITES)}")
                print('Check the script syntax and try again')
                sys.exit(1)
        print('Sites to download: ' + ' '.join(kwargs['site']))

    if args.scans:
        kwargs['scans'] = args.scans
        for scan in kwargs['scans']:
            if scan not in SCANS:
                print(f'Scan {scan} is not a valid scan name')
                print(f"Valid scan names: {', '.join(SCANS)}")
                print('Check the script syntax and try again')
                sys.exit(1)
        print('Scans to download: ' + ' '.join(kwargs['scans']))

    if args.tasks:
        kwargs['tasks'] = args.tasks
        for task in kwargs['tasks']:
            if task not in TASKS_MAP:
                print(f'Task {task} is not a valid task name')
                # Built from TASKS_MAP so the list can never drift from the accepted codes.
                print(f"Valid task names: {', '.join(TASKS_MAP)}")
                print('Check the script syntax and try again')
                sys.exit(1)
        print('Tasks to download: ' + ' '.join(kwargs['tasks']))

    if args.dry_run:
        kwargs['dry_run'] = args.dry_run
        print('Running download script as a dry run.')

    # Exit non-zero only when download_data explicitly reports failure with False.
    if download_data(out_dir, **kwargs) is False:
        sys.exit(1)